﻿//---------------------------------------------------------------------------------
// Microsoft (R) .NET Services SDK
// Software Development Kit
// 
// Copyright (c) Microsoft Corporation. All rights reserved.  
//
// THIS CODE AND INFORMATION ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, 
// EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES 
// OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. 
//---------------------------------------------------------------------------------

namespace Microsoft.AccessControl.SDK.ACSCalculator.AuthorizationManager
{
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Security.Cryptography;
    using System.Text;
    using System.Web;
    using System.Globalization;

    /// <summary>
    /// Validates a form-encoded token of name/value pairs that carries a trailing
    /// HMAC-SHA256 signature, and checks its expiry, issuer and audience against
    /// the values configured at construction time.
    /// </summary>
    public class TokenValidator
    {
        private const string issuerLabel = "Issuer";
        private const string expiresLabel = "ExpiresOn";
        private const string audienceLabel = "Audience";
        private const string hmacSHA256Label = "HMACSHA256";

        // Stored for reference only; no member of this class reads it after construction.
        private readonly string acsHostName;

        private readonly byte[] trustedSigningKey;
        private readonly Uri trustedTokenIssuer;
        private readonly Uri trustedAudienceValue;

        /// <summary>
        /// Creates a validator bound to one issuer, one audience and one signing key.
        /// </summary>
        /// <param name="acsHostName">Host name suffix used to build the trusted issuer URI.</param>
        /// <param name="trustedSolution">Host name prefix used to build the trusted issuer URI.</param>
        /// <param name="trustedAudienceValue">Absolute URI the token's Audience value must equal.</param>
        /// <param name="trustedSigningKey">Key used to recompute the token's HMAC-SHA256 signature.</param>
        /// <exception cref="ArgumentNullException"><paramref name="trustedSigningKey"/> is null.</exception>
        public TokenValidator(string acsHostName, string trustedSolution, string trustedAudienceValue, byte[] trustedSigningKey)
        {
            // Fail at construction rather than on the first Validate call.
            if (trustedSigningKey == null)
            {
                throw new ArgumentNullException("trustedSigningKey");
            }

            this.acsHostName = acsHostName;
            this.trustedSigningKey = trustedSigningKey;

            // The trusted issuer is https://{trustedSolution}.{acsHostName}/WRAPv0.8
            this.trustedTokenIssuer = new Uri(string.Format(
                CultureInfo.InvariantCulture,
                "https://{0}.{1}/WRAPv0.8",
                trustedSolution,
                acsHostName));

            this.trustedAudienceValue = new Uri(trustedAudienceValue);
        }

        /// <summary>
        /// Checks the token's signature, expiry, issuer and audience, in that order.
        /// </summary>
        /// <param name="token">The raw form-encoded token, including its signature pair.</param>
        /// <returns>
        /// True only if every check passes. A null or empty token, a token without exactly
        /// one signature separator, or a token whose signature does not match yields false.
        /// </returns>
        /// <exception cref="ArgumentException">
        /// A correctly signed token is not well-formed (see <see cref="GetNameValues"/>)
        /// or has no ExpiresOn value.
        /// </exception>
        /// <exception cref="FormatException">The ExpiresOn value is not an unsigned integer.</exception>
        /// <exception cref="OverflowException">The ExpiresOn value does not fit in 64 bits.</exception>
        public bool Validate(string token)
        {
            // The signature is verified before any value in the token is interpreted.
            if (!IsHMACValid(token, this.trustedSigningKey))
            {
                return false;
            }

            // Parse once and share the result between the remaining checks.
            Dictionary<string, string> tokenValues = this.GetNameValues(token);

            if (IsExpired(tokenValues))
            {
                return false;
            }

            if (!this.IsIssuerTrusted(tokenValues))
            {
                return false;
            }

            if (!this.IsAudienceTrusted(tokenValues))
            {
                return false;
            }

            return true;
        }

        /// <summary>
        /// Splits a form-encoded string on '&amp;' and '=' and URL-decodes each name and value.
        /// Empty segments (for example from a leading, trailing or doubled '&amp;') are skipped.
        /// </summary>
        /// <param name="token">The form-encoded string to parse.</param>
        /// <returns>A dictionary mapping each decoded name to its decoded value.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="token"/> is null or empty, a segment does not contain exactly one
        /// '=' character, or two segments decode to the same name.
        /// </exception>
        public Dictionary<string, string> GetNameValues(string token)
        {
            if (string.IsNullOrEmpty(token))
            {
                throw new ArgumentException("Token must not be null or empty", "token");
            }

            Dictionary<string, string> values = new Dictionary<string, string>();

            foreach (string rawNameValue in token.Split('&'))
            {
                if (rawNameValue.Length == 0)
                {
                    continue;
                }

                string[] nameValue = rawNameValue.Split('=');

                if (nameValue.Length != 2)
                {
                    throw new ArgumentException("Invalid formEncodedstring - contains a name/value pair missing an = character", "token");
                }

                // The dictionary is keyed by the decoded name, so the duplicate check
                // has to use the decoded name as well; comparing the raw text would
                // both miss real repeats and flag names that only look alike encoded.
                string name = HttpUtility.UrlDecode(nameValue[0]);

                if (values.ContainsKey(name))
                {
                    throw new ArgumentException("Repeated name/value pair in form", "token");
                }

                values.Add(name, HttpUtility.UrlDecode(nameValue[1]));
            }

            return values;
        }

        /// <summary>
        /// Returns the current UTC time as whole seconds elapsed since 1970-01-01T00:00:00.
        /// </summary>
        private static ulong GenerateTimeStamp()
        {
            DateTime epoch = new DateTime(1970, 1, 1, 0, 0, 0, 0, DateTimeKind.Utc);
            TimeSpan sinceEpoch = DateTime.UtcNow - epoch;
            return Convert.ToUInt64(sinceEpoch.TotalSeconds);
        }

        /// <summary>
        /// Compares two strings without exiting at the first differing character, so the
        /// time taken does not reveal how long a matching prefix is. Only a length mismatch
        /// returns early.
        /// </summary>
        private static bool FixedTimeEquals(string left, string right)
        {
            if (left.Length != right.Length)
            {
                return false;
            }

            int difference = 0;
            for (int i = 0; i < left.Length; i++)
            {
                difference |= left[i] ^ right[i];
            }

            return difference == 0;
        }

        /// <summary>
        /// Looks up <paramref name="label"/> in the parsed token and reports whether its
        /// value is an absolute URI equal to <paramref name="trustedUri"/>. A missing,
        /// empty or unparsable value is reported as not matching.
        /// </summary>
        private static bool ValueEqualsUri(Dictionary<string, string> tokenValues, string label, Uri trustedUri)
        {
            string rawValue;
            if (!tokenValues.TryGetValue(label, out rawValue) || string.IsNullOrEmpty(rawValue))
            {
                return false;
            }

            // TryCreate instead of the Uri constructor: a malformed value means
            // "not trusted", not an exception escaping from Validate.
            Uri candidate;
            if (!Uri.TryCreate(rawValue, UriKind.Absolute, out candidate))
            {
                return false;
            }

            return candidate.Equals(trustedUri);
        }

        /// <summary>
        /// Reports whether the token's Audience value equals the configured audience URI.
        /// </summary>
        private bool IsAudienceTrusted(Dictionary<string, string> tokenValues)
        {
            return ValueEqualsUri(tokenValues, audienceLabel, this.trustedAudienceValue);
        }

        /// <summary>
        /// Reports whether the token's Issuer value equals the configured issuer URI.
        /// </summary>
        private bool IsIssuerTrusted(Dictionary<string, string> tokenValues)
        {
            return ValueEqualsUri(tokenValues, issuerLabel, this.trustedTokenIssuer);
        }

        /// <summary>
        /// Recomputes the HMAC-SHA256 of everything before the "&amp;HMACSHA256=" separator
        /// and compares its URL-encoded Base64 form with the text after the separator.
        /// </summary>
        /// <param name="swt">The raw token including its signature pair.</param>
        /// <param name="sha256HMACKey">The key used to compute the signature.</param>
        /// <returns>
        /// True if the signatures match; false if they differ, if the token is null or
        /// empty, or if the separator does not occur exactly once.
        /// </returns>
        private static bool IsHMACValid(string swt, byte[] sha256HMACKey)
        {
            if (string.IsNullOrEmpty(swt))
            {
                return false;
            }

            string[] swtWithSignature = swt.Split(new string[] { "&" + hmacSHA256Label + "=" }, StringSplitOptions.None);

            if (swtWithSignature.Length != 2)
            {
                return false;
            }

            byte[] locallyGeneratedSignatureInBytes;
            using (HMACSHA256 hmac = new HMACSHA256(sha256HMACKey))
            {
                locallyGeneratedSignatureInBytes = hmac.ComputeHash(Encoding.ASCII.GetBytes(swtWithSignature[0]));
            }

            string locallyGeneratedSignature = HttpUtility.UrlEncode(Convert.ToBase64String(locallyGeneratedSignatureInBytes));

            return FixedTimeEquals(locallyGeneratedSignature, swtWithSignature[1]);
        }

        /// <summary>
        /// Reports whether the current time is later than the token's ExpiresOn value,
        /// both expressed as seconds since 1970-01-01 UTC.
        /// </summary>
        /// <exception cref="ArgumentException">The token has no ExpiresOn value.</exception>
        /// <exception cref="FormatException">The ExpiresOn value is not an unsigned integer.</exception>
        /// <exception cref="OverflowException">The ExpiresOn value does not fit in 64 bits.</exception>
        private static bool IsExpired(Dictionary<string, string> tokenValues)
        {
            string expiresOnValue;
            if (!tokenValues.TryGetValue(expiresLabel, out expiresOnValue))
            {
                throw new ArgumentException("Token does not contain an ExpiresOn value", "token");
            }

            // The value is machine-generated, so parse it independently of the
            // culture of the thread that happens to run the validation.
            ulong expiresOn = Convert.ToUInt64(expiresOnValue, CultureInfo.InvariantCulture);

            return GenerateTimeStamp() > expiresOn;
        }
    }
}

